# ======================================================================
# Copyright CERFACS (March 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

from qat.lang.AQASM.gates import X, CNOT, CCNOT
from qat.lang.AQASM.routines import QRoutine
from qat.lang.AQASM.misc import build_gate
from qat.lang.AQASM.gates import AbstractGate

from qaths.routines.multiply_controlled_toffoli import n_controlled_x


@build_gate("_maj", [], arity=3)
def _maj() -> QRoutine:
    r"""Build the 3-qubit MAJ (majority) gate of the Cuccaro ripple-carry adder.

    On the qubits ``(c, b, a)`` (local indices 0, 1 and 2) the gate maps ::

        |c>|b>|a>  ->  |c XOR a>|b XOR a>|MAJ(a, b, c)>

    The last qubit therefore ends up holding the majority of the three input
    bits, i.e. the carry produced when adding ``a``, ``b`` and the carry ``c``.

    :return: a 3-qubit routine implementing MAJ.
    """
    carry, b_bit, a_bit = 0, 1, 2
    maj = QRoutine(arity=3)
    # Express b and c relative to a: b <- b ^ a, then c <- c ^ a.
    maj.apply(CNOT, a_bit, b_bit)
    maj.apply(CNOT, a_bit, carry)
    # a <- a ^ (b ^ a)(c ^ a), which is exactly MAJ(a, b, c).
    maj.apply(CCNOT, carry, b_bit, a_bit)
    return maj


@build_gate("compare_qubit", [int], arity=lambda n: 2 * n + 1)
def qubit_compare_high_bit_maj(n: int) -> QRoutine:
    r"""Compare the two given quantum registers.

    The returned routine acts on :math:`2n + 1` qubits organised as follows: ::

        |        |a>        |        |b>        |  |res>  |
        |   .   .   .   .   |   .   .   .   .   |    .    |
        |   0          n-1  |   n         2*n-1 |   2*n   |

    It also allocates one ancilla qubit (the incoming carry of the MAJ chain)
    that must start in :math:`\ket{0}` and is returned in :math:`\ket{0}`.

    The comparator computes the output carry of ``a' + b``, where ``'`` is the
    bitwise complementation on ``n`` bits. Since ``a' = 2^n - 1 - a``, the sum
    ``a' + b = 2^n - 1 + (b - a)`` overflows ``n`` bits exactly when ``a < b``.
    This carry is the borrow (high bit) of the unsigned subtraction ``a - b``,
    where underflow cycles to the highest value (i.e. (0 - 1) = (2^n - 1)),
    by the identity ::

        a - b = (a' + b)'

    The carry ripples from ``a[0]``/``b[0]`` up to ``a[n-1]``/``b[n-1]``, so
    index 0 is treated as the least significant bit.
    NOTE(review): ``equals_toffoli`` in this module treats registers as
    big-endian (index 0 most significant); confirm which bit order the callers
    of this comparator use.

    :param n: The number of qubits representing each of a and b.
    :return: a routine XOR-ing ``a < b`` into the res qubit (so res is set to 1
        if a < b when it starts in :math:`\ket{0}`).
    :raises ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"Cannot compare registers of {n} qubits (need n >= 1).")

    rout = QRoutine()
    # Incoming carry of the MAJ chain: one clean ancilla qubit. The MAJ^dagger
    # chain below returns it to |0>.
    ic = rout.new_wires(1)
    rout.set_ancillae(ic)
    a = rout.new_wires(n)
    b = rout.new_wires(n)
    res = rout.new_wires(1)

    # Complementation of a
    for i in range(n):
        rout.protected_apply(X, a[i])
    # End of complementation of a

    # High-bit computation: after the MAJ applied at position i, a[i] holds
    # the carry out of position i, so a[n - 1] ends up with the output carry.
    rout.protected_apply(_maj(), ic, b[0], a[0])
    for i in range(1, n):
        rout.protected_apply(_maj(), a[i - 1], b[i], a[i])
    # Copy the output carry into the result qubit.
    rout.apply(CNOT, a[n - 1], res)
    # Uncompute the carry chain in reverse order. This restores a, b and ic.
    for i in range(n - 1, 0, -1):
        rout.protected_apply(_maj().dag(), a[i - 1], b[i], a[i])
    rout.protected_apply(_maj().dag(), ic, b[0], a[0])
    # End of high-bit computation

    # Inverse complementation of a
    for i in range(n):
        rout.protected_apply(X, a[i])
    # End of inverse complementation of a

    return rout


@build_gate("compare_qubit", [int], arity=lambda n: 2 * n + 1)
def qubit_compare_thapliyal(n: int) -> QRoutine:
    r"""Compare the two given quantum registers.

    The returned routine acts on :math:`2n + 1` qubits organised as follows: ::

        |        |a>        |        |b>        |  |res>  |
        |   .   .   .   .   |   .   .   .   .   |    .    |
        |   0          n-1  |   n         2*n-1 |   2*n   |

    The comparator computes the output carry of ``a' + b``, where ``'`` is the
    bitwise complementation on ``n`` bits. Since ``a' = 2^n - 1 - a``, the sum
    ``a' + b`` overflows ``n`` bits exactly when ``a < b``. This carry is the
    borrow (high bit) of the unsigned subtraction ``a - b``, where underflow
    cycles to the highest value (i.e. (0 - 1) = (2^n - 1)), by the identity ::

        a - b = (a' + b)'

    In other words, if (a' + b) overflows then a < b.

    The carry ripples from index 0 to index n-1, so index 0 is treated as the
    least significant bit.
    NOTE(review): ``equals_toffoli`` in this module treats registers as
    big-endian (index 0 most significant); confirm which bit order the callers
    of this comparator use.

    :param n: The number of qubits representing each of a and b.
    :return: a routine XOR-ing ``a < b`` into the res qubit (so res is set to 1
        if a < b when it starts in :math:`\ket{0}`).
    :raises ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"Cannot compare registers of {n} qubits (need n >= 1).")

    rout = QRoutine()
    # Label each qubit for clearer code.
    a = rout.new_wires(n)
    b = rout.new_wires(n)
    res = rout.new_wires(1)

    # For better code, we make a[n] = res as res plays the role of the output
    # carry of the adder.
    a.append(res)

    # Complementation of a
    for i in range(n):
        rout.protected_apply(X, a[i])
    # End of complementation of a

    # Following now the steps of the adder without input carry from the paper
    # https://arxiv.org/pdf/1712.02630.pdf (below, a, b denote the values after
    # complementation and c[i] the carry into position i, with c[0] = 0).
    # Step 1: b[i] <- a[i] ^ b[i] for i >= 1. Position 0 must be skipped: Step 2
    # does not touch a[1], so the first Toffoli must see b[0] unchanged to
    # compute c[1] = a[0] & b[0].
    for i in range(1, n):
        rout.protected_apply(CNOT, a[i], b[i])
    # Step 2: a[i + 1] <- a[i + 1] ^ a[i] for i >= 1, in descending order so
    # that each control still holds its original value.
    for i in reversed(range(1, n)):
        rout.protected_apply(CNOT, a[i], a[i + 1])
    # Step 3: after iteration i, a[i + 1] holds a[i + 1] ^ c[i + 1], because
    # a ^ (a ^ b)(a ^ c) == MAJ(a, b, c).
    for i in range(0, n - 1):
        rout.protected_apply(CCNOT, b[i], a[i], a[i + 1])
    # Compute the result: a[n] = res <- res ^ c[n] (for n == 1 the three steps
    # above are empty and this directly computes a[0] & b[0]).
    rout.protected_apply(CCNOT, b[n - 1], a[n - 1], a[n])
    # Start modifications here. The output carry (i.e. the result of the
    # comparison) has been computed with the last gate applied. We do not need
    # to continue the computations described in the paper for a full adder, as
    # we only need a comparator. We just uncompute all the previous operations
    # except the one computing the result.

    # Step 3
    for i in reversed(range(0, n - 1)):
        rout.protected_apply(CCNOT.dag(), b[i], a[i], a[i + 1])
    # Step 2 (ascending order, the reverse of the computation). Its last gate
    # also removes the a[n - 1] that Step 2 XOR-ed into res.
    for i in range(1, n):
        rout.protected_apply(CNOT.dag(), a[i], a[i + 1])
    # Step 1
    for i in reversed(range(1, n)):
        rout.protected_apply(CNOT.dag(), a[i], b[i])
    # Inverse complementation of a
    for i in range(n):
        rout.protected_apply(X.dag(), a[i])
    # End of inverse complementation of a

    return rout


@build_gate("compare_const", [int, int], arity=lambda n, rhs: n + 1)
def compare_adder(n: int, rhs: int) -> QRoutine:
    r"""Compare the given quantum register with rhs.

    The returned routine acts on :math:`n + 1` qubits organised as follows: ::

        |        |a>        |  |res>  |
        |   .   .   .   .   |    .    |
        |   0          n-1  |    n    |

    It also allocates one ancilla qubit that receives the output carry of the
    addition. The ancilla must start in :math:`\ket{0}` and is returned in
    :math:`\ket{0}`.

    The comparator computes the output carry of ``a' + rhs``, where ``'`` is
    the bitwise complementation on ``n`` bits. Since ``a' = 2^n - 1 - a``, the
    sum ``a' + rhs`` reaches ``2^n`` exactly when ``a < rhs``. This carry is
    the borrow (high bit) of the unsigned subtraction ``a - rhs``, by the
    identity ::

        a - rhs = (a' + rhs)'

    Values of ``rhs`` outside ``[0, 2^n]`` are clamped to that interval:
    ``a < rhs`` is always false for ``rhs <= 0`` and always true for
    ``rhs >= 2^n``, which is exactly what ``rhs = 0`` and ``rhs = 2^n``
    compute. Without clamping, ``a' + rhs`` could wrap around the
    (n + 1)-qubit register and produce a wrong answer.

    :param n: The number of qubits representing a.
    :param rhs: The number to compare a to.
    :return: a routine XOR-ing ``a < rhs`` into the res qubit (so res is set to
        1 if a < rhs when it starts in :math:`\ket{0}`).
    """
    # Clamp rhs to [0, 2**n] (see docstring) so that a' + rhs < 2**(n + 1).
    rhs = min(max(rhs, 0), 2 ** n)

    rout = QRoutine()
    a = rout.new_wires(n)
    # Output carry: prepended to a to form the (n + 1)-qubit register (oc, a)
    # given to the adder.
    oc = rout.new_wires(1)
    rout.set_ancillae(oc)
    res = rout.new_wires(1)

    # NOTE(review): assumes add_const(q_size, rhs) adds rhs modulo 2**q_size and
    # treats its first qubit as the most significant one (big-endian, as stated
    # in equals_toffoli), so that oc receives the carry. TODO confirm against
    # the add_const implementation.
    adder = AbstractGate("add_const", [int, int], arity=lambda q_size, rhs: q_size)

    # Complementation of a
    for i in range(n):
        rout.protected_apply(X, a[i])
    # End of complementation of a

    # Compute the carry, copy it into res, then uncompute it.
    rout.protected_apply(adder(n + 1, rhs), oc, a)
    rout.apply(CNOT, oc, res)
    rout.protected_apply(adder(n + 1, rhs).dag(), oc, a)

    # Inverse complementation of a
    for i in range(n):
        rout.protected_apply(X, a[i])
    # End of inverse complementation of a

    return rout


@build_gate("compare_const", [int, int], arity=lambda n, rhs: n + 1)
def arithmetic_compare(n: int, rhs: int) -> QRoutine:
    r"""Compare the given quantum register with rhs.

    The returned routine needs :math:`2n` qubits organised as follows: ::

        |       |lhs>       |  |res>  |      |dirty>      |
        |   .   .   .   .   |    .    |   .   .   .   .   |
        |   0          n-1  |    n    |  n+1        2*n-1 |

    The :math:`\ket{\text{dirty}}` register is meant to be "borrowed". It does
    not need to be initialised to :math:`\ket{0}`: it can be given in an
    arbitrary state and will be returned in the exact same state.

    The comparator relies on the fact that the high bit of (lhs - rhs) is 1 if
    lhs < rhs, together with the identity ::

        lhs - rhs = (lhs' + rhs)'

    where ' is the bitwise complementation. The high bit itself is obtained
    from the abstract ``high_bit_compute`` gate, whose implementation is
    provided elsewhere.

    NOTE: the |dirty> register is currently declared as an ancilla. This is
    non-optimal, but support for "borrowed" or "dirty" qubits in the qat
    library is still awaited.

    :param n: The number of qubits representing lhs.
    :param rhs: The number to compare lhs to.
    :return: a routine setting the res qubit to 1 if lhs < rhs.
    """
    high_bit_compute = AbstractGate(
        "high_bit_compute", [int, int], arity=lambda n, rhs: 2 * n
    )

    routine = QRoutine()
    operand = routine.new_wires(n)
    result = routine.new_wires(1)
    borrowed = routine.new_wires(n - 1)
    routine.set_ancillae(borrowed)

    # Bitwise complement of the compared register.
    for index in range(n):
        routine.protected_apply(X, operand[index])

    # Write the high bit of (operand' + rhs) into the result qubit.
    routine.apply(high_bit_compute(n, rhs), operand, result, borrowed)

    # Undo the bitwise complement.
    for index in range(n):
        routine.protected_apply(X.dag(), operand[index])

    return routine


@build_gate(
    "compare_range_const", [int, int, int], arity=lambda n, lrhs, rrhs: n + 1
)
def range_compare(n: int, lrhs: int, rrhs: int) -> QRoutine:
    r"""Compare the given quantum register with one value of the range [lrhs, rrhs].

    The returned routine acts on the same qubits as the underlying
    ``compare_const`` comparator (:math:`n + 1` qubits).

    This method can be used when we need to compare a quantum register to a
    given integer and we do not need a valid behaviour for some values
    directly after that integer.

    Thanks to the relaxed requirements on the output of this comparison, more
    optimisations can be performed.

    This function **is deterministic** (two calls with the same parameters will
    return the same gate).

    This function returns the ``compare_const`` gate built with the parameters
    `n` and `rhs`, for `rhs` a value in the closed range [lrhs, rrhs]. Every
    value of that range gives a valid result (see the example below). The
    chosen value is the one with the fewest '1' in binary, the smallest one in
    case of a tie.

    .. admonition:: Example

       Imagine you need to compare a quantum register :math:`\ket{x}` to
       :math:`4` and that you do not care about the result between 4 and 7, i.e.:

       #. the result of comp(x, 4) should be valid (i.e. True) for :math:`x < 4`.
       #. the result of comp(x, 4) does not need to be valid for
          :math:`4 \leqslant x < 7`.
       #. the result of comp(x, 4) should be valid (i.e. False) for
          :math:`7 \leqslant x`.

       Any rhs in [4, 7] satisfies these three constraints.

    This kind of situation arises in the construction of oracles when some rows
    are empty (no non-zero entry). When a row is empty in oracle construction,
    the only oracle that should take care of this row is the weight oracle. All
    the other oracles can return any value they want.

    :param n: The number of qubits representing a.
    :param lrhs: The lower bound (inclusive) for the rhs values we do not care about.
    :param rrhs: The upper bound (inclusive) for the rhs values we do not care about.
    :return: an optimised gate setting the res qubit to 1 if a < rhs.
    :raises ValueError: if ``lrhs > rrhs`` (empty range).
    """
    if lrhs > rrhs:
        raise ValueError(
            f"Empty range of right-hand sides: lrhs ({lrhs}) > rrhs ({rrhs})."
        )

    comparator = AbstractGate(
        "compare_const", [int, int], arity=lambda q_size, rhs: q_size + 1
    )

    # Search for the integer that has the fewest '1' in binary, as this should
    # reduce the number of quantum gates used by the comparator. min() keeps the
    # first minimum of the increasing range, i.e. the smallest integer on ties.
    optimum_integer = min(range(lrhs, rrhs + 1), key=lambda i: bin(i).count("1"))

    return comparator(n, optimum_integer)


@build_gate("equals_const", [int, int], arity=lambda n, rhs: n + 1)
def equals_comparator(n: int, rhs: int) -> QRoutine:
    r"""Compare the given quantum register with rhs.

    The returned routine acts on :math:`n + 1` qubits organised as follows: ::

        |        |a>        |  |res>  |
        |   .   .   .   .   |    .    |
        |   0          n-1  |    n    |

    It also allocates two ancilla qubits, which must start in
    :math:`\ket{0}` and are returned in :math:`\ket{0}`.

    Equality is computed as ``(a >= rhs) and (a < rhs + 1)`` with two calls to
    the ``compare_const`` gate.

    :param n: The number of qubits representing a.
    :param rhs: The number to compare a to.
    :return: a routine setting the res qubit to 1 if a == rhs.
    """
    rout = QRoutine()
    a = rout.new_wires(n)
    res = rout.new_wires(1)
    ancillas = rout.new_wires(2)
    rout.set_ancillae(ancillas)

    # compare_const acts on n + 1 qubits: the n qubits of a, then the result
    # qubit. Passing more qubits would not match this arity.
    comparator = AbstractGate(
        "compare_const", [int, int], arity=lambda q_size, rhs: q_size + 1
    )

    # |ancillas[0]>  :=  |a>  >=  rhs
    rout.protected_apply(comparator(n, rhs), a, ancillas[0])
    rout.protected_apply(X, ancillas[0])
    # |ancillas[1]>  :=  |a>  <  rhs+1
    rout.protected_apply(comparator(n, rhs + 1), a, ancillas[1])

    # |res>  :=  ( |ancillas[0]>  &&  |ancillas[1]> )
    rout.apply(CCNOT, ancillas[0], ancillas[1], res)

    # Uncompute
    rout.protected_apply(comparator(n, rhs + 1).dag(), a, ancillas[1])
    rout.protected_apply(X.dag(), ancillas[0])
    rout.protected_apply(comparator(n, rhs).dag(), a, ancillas[0])

    return rout


@build_gate("equals_const", [int, int], arity=lambda n, rhs: n + 1)
def equals_toffoli(n: int, rhs: int) -> QRoutine:
    r"""Compare the given quantum register with rhs.

    The returned routine acts on :math:`n + 1` qubits plus one ancilla,
    organised as follows: ::

        |        |a>        |  |res>  | |ancilla> |
        |   .   .   .   .   |    .    |     .     |
        |   0          n-1  |    n    |    n+1    |

    The equality tester uses a multiply-controlled CNOT. The qubits of a that
    must be :math:`\ket{0}` for the equality to hold are negated around it.

    :param n: The number of qubits representing a.
    :param rhs: The number to compare a to.
    :return: a routine setting the res qubit to 1 if a == rhs.
    :raises ValueError: if rhs cannot be represented on n qubits, i.e. if
        ``rhs < 0`` or ``rhs >= 2**n``.
    """
    # An out-of-range rhs would produce a bit string longer than n (IndexError
    # or, worse, a silent comparison against the top n bits only) or start with
    # '-' for negative values.
    if not 0 <= rhs < 2 ** n:
        raise ValueError(
            f"rhs = {rhs} cannot be represented on {n} qubits "
            f"(expected 0 <= rhs < {2 ** n})."
        )

    rout = QRoutine()
    controls = rout.new_wires(n)
    target = rout.new_wires(1)
    ancilla = rout.new_wires(1)
    rout.set_ancillae(ancilla)

    # The control register represents an unsigned integer in big-endian (see
    # convention), so the bits of rhs are listed most significant first.
    rhs_bits = format(rhs, f"0{n}b")
    # If a bit of rhs is 0, the matching qubit of controls should be |0> for the
    # equality test to return True. Such qubits are negated before being given
    # to the multiply-controlled Toffoli gate.
    zero_positions = [i for i, bit in enumerate(rhs_bits) if bit == "0"]

    for i in zero_positions:
        rout.protected_apply(X, controls[i])

    # Apply the multiply-controlled Toffoli gate
    rout.apply(n_controlled_x(n), controls, target, ancilla)

    # Revert the X gates
    for i in zero_positions:
        rout.protected_apply(X.dag(), controls[i])

    return rout
